{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 405 DQN Reinforcement Learning\n",
    "\n",
    "View more, visit my tutorial page: https://morvanzhou.github.io/tutorials/\n",
    "My Youtube Channel: https://www.youtube.com/user/MorvanZhou\n",
    "More about Reinforcement learning: https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/\n",
    "\n",
    "Dependencies:\n",
    "* torch: 0.1.11\n",
    "* gym: 0.8.1\n",
    "* numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import gym"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2017-06-20 22:23:40,418] Making new env: CartPole-v0\n"
     ]
    }
   ],
   "source": [
    "# Hyper Parameters\n",
    "BATCH_SIZE = 32\n",
    "LR = 0.01                   # learning rate\n",
    "EPSILON = 0.9               # greedy policy\n",
    "GAMMA = 0.9                 # reward discount\n",
    "TARGET_REPLACE_ITER = 100   # target update frequency\n",
    "MEMORY_CAPACITY = 2000\n",
    "env = gym.make('CartPole-v0')\n",
    "env = env.unwrapped\n",
    "N_ACTIONS = env.action_space.n\n",
    "N_STATES = env.observation_space.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self, ):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(N_STATES, 10)\n",
    "        self.fc1.weight.data.normal_(0, 0.1)   # initialization\n",
    "        self.out = nn.Linear(10, N_ACTIONS)\n",
    "        self.out.weight.data.normal_(0, 0.1)   # initialization\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        actions_value = self.out(x)\n",
    "        return actions_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class DQN(object):\n",
    "    def __init__(self):\n",
    "        self.eval_net, self.target_net = Net(), Net()\n",
    "\n",
    "        self.learn_step_counter = 0                                     # for target updating\n",
    "        self.memory_counter = 0                                         # for storing memory\n",
    "        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))     # initialize memory\n",
    "        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)\n",
    "        self.loss_func = nn.MSELoss()\n",
    "\n",
    "    def choose_action(self, x):\n",
    "        x = Variable(torch.unsqueeze(torch.FloatTensor(x), 0))\n",
    "        # input only one sample\n",
    "        if np.random.uniform() < EPSILON:   # greedy\n",
    "            actions_value = self.eval_net.forward(x)\n",
    "            action = torch.max(actions_value, 1)[1].data.numpy()[0, 0]     # return the argmax\n",
    "        else:   # random\n",
    "            action = np.random.randint(0, N_ACTIONS)\n",
    "        return action\n",
    "\n",
    "    def store_transition(self, s, a, r, s_):\n",
    "        transition = np.hstack((s, [a, r], s_))\n",
    "        # replace the old memory with new memory\n",
    "        index = self.memory_counter % MEMORY_CAPACITY\n",
    "        self.memory[index, :] = transition\n",
    "        self.memory_counter += 1\n",
    "\n",
    "    def learn(self):\n",
    "        # target parameter update\n",
    "        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:\n",
    "            self.target_net.load_state_dict(self.eval_net.state_dict())\n",
    "        self.learn_step_counter += 1\n",
    "\n",
    "        # sample batch transitions\n",
    "        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)\n",
    "        b_memory = self.memory[sample_index, :]\n",
    "        b_s = Variable(torch.FloatTensor(b_memory[:, :N_STATES]))\n",
    "        b_a = Variable(torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int)))\n",
    "        b_r = Variable(torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2]))\n",
    "        b_s_ = Variable(torch.FloatTensor(b_memory[:, -N_STATES:]))\n",
    "\n",
    "        # q_eval w.r.t the action in experience\n",
    "        q_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1)\n",
    "        q_next = self.target_net(b_s_).detach()     # detach from graph, don't backpropagate\n",
    "        q_target = b_r + GAMMA * q_next.max(1)[0]   # shape (batch, 1)\n",
    "        loss = self.loss_func(q_eval, q_target)\n",
    "\n",
    "        self.optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        self.optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "dqn = DQN()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Collecting experience...\n",
      "Ep:  201 | Ep_r:  1.59\n",
      "Ep:  202 | Ep_r:  4.18\n",
      "Ep:  203 | Ep_r:  2.73\n",
      "Ep:  204 | Ep_r:  1.97\n",
      "Ep:  205 | Ep_r:  1.18\n",
      "Ep:  206 | Ep_r:  0.86\n",
      "Ep:  207 | Ep_r:  2.88\n",
      "Ep:  208 | Ep_r:  1.63\n",
      "Ep:  209 | Ep_r:  3.91\n",
      "Ep:  210 | Ep_r:  3.6\n",
      "Ep:  211 | Ep_r:  0.98\n",
      "Ep:  212 | Ep_r:  3.85\n",
      "Ep:  213 | Ep_r:  1.81\n",
      "Ep:  214 | Ep_r:  2.32\n",
      "Ep:  215 | Ep_r:  3.75\n",
      "Ep:  216 | Ep_r:  3.53\n",
      "Ep:  217 | Ep_r:  4.75\n",
      "Ep:  218 | Ep_r:  2.4\n",
      "Ep:  219 | Ep_r:  0.64\n",
      "Ep:  220 | Ep_r:  1.15\n",
      "Ep:  221 | Ep_r:  2.3\n",
      "Ep:  222 | Ep_r:  7.37\n",
      "Ep:  223 | Ep_r:  1.25\n",
      "Ep:  224 | Ep_r:  5.02\n",
      "Ep:  225 | Ep_r:  10.29\n",
      "Ep:  226 | Ep_r:  17.54\n",
      "Ep:  227 | Ep_r:  36.2\n",
      "Ep:  228 | Ep_r:  6.61\n",
      "Ep:  229 | Ep_r:  10.04\n",
      "Ep:  230 | Ep_r:  55.19\n",
      "Ep:  231 | Ep_r:  10.03\n",
      "Ep:  232 | Ep_r:  13.25\n",
      "Ep:  233 | Ep_r:  8.75\n",
      "Ep:  234 | Ep_r:  3.83\n",
      "Ep:  235 | Ep_r:  -0.92\n",
      "Ep:  236 | Ep_r:  5.12\n",
      "Ep:  237 | Ep_r:  3.56\n",
      "Ep:  238 | Ep_r:  5.69\n",
      "Ep:  239 | Ep_r:  8.43\n",
      "Ep:  240 | Ep_r:  29.27\n",
      "Ep:  241 | Ep_r:  17.95\n",
      "Ep:  242 | Ep_r:  44.77\n",
      "Ep:  243 | Ep_r:  98.0\n",
      "Ep:  244 | Ep_r:  38.78\n",
      "Ep:  245 | Ep_r:  45.02\n",
      "Ep:  246 | Ep_r:  27.73\n",
      "Ep:  247 | Ep_r:  36.96\n",
      "Ep:  248 | Ep_r:  48.98\n",
      "Ep:  249 | Ep_r:  111.36\n",
      "Ep:  250 | Ep_r:  95.61\n",
      "Ep:  251 | Ep_r:  149.77\n",
      "Ep:  252 | Ep_r:  29.96\n",
      "Ep:  253 | Ep_r:  2.79\n",
      "Ep:  254 | Ep_r:  20.1\n",
      "Ep:  255 | Ep_r:  24.25\n",
      "Ep:  256 | Ep_r:  3074.75\n",
      "Ep:  257 | Ep_r:  1258.49\n",
      "Ep:  258 | Ep_r:  127.39\n",
      "Ep:  259 | Ep_r:  283.46\n",
      "Ep:  260 | Ep_r:  166.96\n",
      "Ep:  261 | Ep_r:  101.71\n",
      "Ep:  262 | Ep_r:  63.45\n",
      "Ep:  263 | Ep_r:  288.94\n",
      "Ep:  264 | Ep_r:  130.49\n",
      "Ep:  265 | Ep_r:  207.05\n",
      "Ep:  266 | Ep_r:  183.71\n",
      "Ep:  267 | Ep_r:  142.75\n",
      "Ep:  268 | Ep_r:  126.53\n",
      "Ep:  269 | Ep_r:  310.79\n",
      "Ep:  270 | Ep_r:  863.2\n",
      "Ep:  271 | Ep_r:  365.12\n",
      "Ep:  272 | Ep_r:  659.52\n",
      "Ep:  273 | Ep_r:  103.98\n",
      "Ep:  274 | Ep_r:  554.83\n",
      "Ep:  275 | Ep_r:  246.01\n",
      "Ep:  276 | Ep_r:  332.23\n",
      "Ep:  277 | Ep_r:  323.35\n",
      "Ep:  278 | Ep_r:  278.71\n",
      "Ep:  279 | Ep_r:  613.6\n",
      "Ep:  280 | Ep_r:  152.21\n",
      "Ep:  281 | Ep_r:  402.02\n",
      "Ep:  282 | Ep_r:  351.4\n",
      "Ep:  283 | Ep_r:  115.87\n",
      "Ep:  284 | Ep_r:  163.26\n",
      "Ep:  285 | Ep_r:  631.0\n",
      "Ep:  286 | Ep_r:  263.47\n",
      "Ep:  287 | Ep_r:  511.21\n",
      "Ep:  288 | Ep_r:  337.18\n",
      "Ep:  289 | Ep_r:  819.76\n",
      "Ep:  290 | Ep_r:  190.83\n",
      "Ep:  291 | Ep_r:  442.98\n",
      "Ep:  292 | Ep_r:  537.24\n",
      "Ep:  293 | Ep_r:  1101.12\n",
      "Ep:  294 | Ep_r:  178.42\n",
      "Ep:  295 | Ep_r:  225.61\n",
      "Ep:  296 | Ep_r:  252.62\n",
      "Ep:  297 | Ep_r:  617.5\n",
      "Ep:  298 | Ep_r:  617.8\n",
      "Ep:  299 | Ep_r:  244.01\n",
      "Ep:  300 | Ep_r:  687.91\n",
      "Ep:  301 | Ep_r:  618.51\n",
      "Ep:  302 | Ep_r:  1405.07\n",
      "Ep:  303 | Ep_r:  456.95\n",
      "Ep:  304 | Ep_r:  340.33\n",
      "Ep:  305 | Ep_r:  502.91\n",
      "Ep:  306 | Ep_r:  441.21\n",
      "Ep:  307 | Ep_r:  255.81\n",
      "Ep:  308 | Ep_r:  403.03\n",
      "Ep:  309 | Ep_r:  229.1\n",
      "Ep:  310 | Ep_r:  308.49\n",
      "Ep:  311 | Ep_r:  165.37\n",
      "Ep:  312 | Ep_r:  153.76\n",
      "Ep:  313 | Ep_r:  442.05\n",
      "Ep:  314 | Ep_r:  229.23\n",
      "Ep:  315 | Ep_r:  128.52\n",
      "Ep:  316 | Ep_r:  358.18\n",
      "Ep:  317 | Ep_r:  319.03\n",
      "Ep:  318 | Ep_r:  381.76\n",
      "Ep:  319 | Ep_r:  199.19\n",
      "Ep:  320 | Ep_r:  418.63\n",
      "Ep:  321 | Ep_r:  223.95\n",
      "Ep:  322 | Ep_r:  222.37\n",
      "Ep:  323 | Ep_r:  405.4\n",
      "Ep:  324 | Ep_r:  311.32\n",
      "Ep:  325 | Ep_r:  184.85\n",
      "Ep:  326 | Ep_r:  1026.71\n",
      "Ep:  327 | Ep_r:  252.41\n",
      "Ep:  328 | Ep_r:  224.93\n",
      "Ep:  329 | Ep_r:  620.02\n",
      "Ep:  330 | Ep_r:  174.54\n",
      "Ep:  331 | Ep_r:  782.45\n",
      "Ep:  332 | Ep_r:  263.79\n",
      "Ep:  333 | Ep_r:  178.63\n",
      "Ep:  334 | Ep_r:  242.84\n",
      "Ep:  335 | Ep_r:  635.43\n",
      "Ep:  336 | Ep_r:  668.89\n",
      "Ep:  337 | Ep_r:  265.42\n",
      "Ep:  338 | Ep_r:  207.81\n",
      "Ep:  339 | Ep_r:  293.09\n",
      "Ep:  340 | Ep_r:  530.23\n",
      "Ep:  341 | Ep_r:  479.26\n",
      "Ep:  342 | Ep_r:  559.77\n",
      "Ep:  343 | Ep_r:  241.39\n",
      "Ep:  344 | Ep_r:  158.83\n",
      "Ep:  345 | Ep_r:  1510.69\n",
      "Ep:  346 | Ep_r:  425.17\n",
      "Ep:  347 | Ep_r:  266.94\n",
      "Ep:  348 | Ep_r:  166.08\n",
      "Ep:  349 | Ep_r:  630.52\n",
      "Ep:  350 | Ep_r:  250.95\n",
      "Ep:  351 | Ep_r:  625.88\n",
      "Ep:  352 | Ep_r:  417.7\n",
      "Ep:  353 | Ep_r:  867.81\n",
      "Ep:  354 | Ep_r:  150.62\n",
      "Ep:  355 | Ep_r:  230.89\n",
      "Ep:  356 | Ep_r:  1017.52\n",
      "Ep:  357 | Ep_r:  190.28\n",
      "Ep:  358 | Ep_r:  396.91\n",
      "Ep:  359 | Ep_r:  305.53\n",
      "Ep:  360 | Ep_r:  131.61\n",
      "Ep:  361 | Ep_r:  387.54\n",
      "Ep:  362 | Ep_r:  298.82\n",
      "Ep:  363 | Ep_r:  207.56\n",
      "Ep:  364 | Ep_r:  248.56\n",
      "Ep:  365 | Ep_r:  589.12\n",
      "Ep:  366 | Ep_r:  179.52\n",
      "Ep:  367 | Ep_r:  130.19\n",
      "Ep:  368 | Ep_r:  1220.84\n",
      "Ep:  369 | Ep_r:  126.35\n",
      "Ep:  370 | Ep_r:  133.31\n",
      "Ep:  371 | Ep_r:  485.81\n",
      "Ep:  372 | Ep_r:  823.4\n",
      "Ep:  373 | Ep_r:  253.26\n",
      "Ep:  374 | Ep_r:  466.06\n",
      "Ep:  375 | Ep_r:  203.27\n",
      "Ep:  376 | Ep_r:  386.5\n",
      "Ep:  377 | Ep_r:  491.02\n",
      "Ep:  378 | Ep_r:  239.45\n",
      "Ep:  379 | Ep_r:  276.93\n",
      "Ep:  380 | Ep_r:  331.98\n",
      "Ep:  381 | Ep_r:  764.79\n",
      "Ep:  382 | Ep_r:  198.29\n",
      "Ep:  383 | Ep_r:  717.18\n",
      "Ep:  384 | Ep_r:  562.15\n",
      "Ep:  385 | Ep_r:  29.44\n",
      "Ep:  386 | Ep_r:  344.95\n",
      "Ep:  387 | Ep_r:  671.87\n",
      "Ep:  388 | Ep_r:  299.81\n",
      "Ep:  389 | Ep_r:  899.76\n",
      "Ep:  390 | Ep_r:  319.04\n",
      "Ep:  391 | Ep_r:  252.11\n",
      "Ep:  392 | Ep_r:  865.62\n",
      "Ep:  393 | Ep_r:  255.64\n",
      "Ep:  394 | Ep_r:  81.74\n",
      "Ep:  395 | Ep_r:  213.13\n",
      "Ep:  396 | Ep_r:  422.33\n",
      "Ep:  397 | Ep_r:  167.47\n",
      "Ep:  398 | Ep_r:  507.34\n",
      "Ep:  399 | Ep_r:  614.0\n"
     ]
    }
   ],
   "source": [
    "\n",
    "print('\\nCollecting experience...')\n",
    "for i_episode in range(400):\n",
    "    s = env.reset()\n",
    "    ep_r = 0\n",
    "    while True:\n",
    "        env.render()\n",
    "        a = dqn.choose_action(s)\n",
    "\n",
    "        # take action\n",
    "        s_, r, done, info = env.step(a)\n",
    "\n",
    "        # modify the reward\n",
    "        x, x_dot, theta, theta_dot = s_\n",
    "        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8\n",
    "        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5\n",
    "        r = r1 + r2\n",
    "\n",
    "        dqn.store_transition(s, a, r, s_)\n",
    "\n",
    "        ep_r += r\n",
    "        if dqn.memory_counter > MEMORY_CAPACITY:\n",
    "            dqn.learn()\n",
    "            if done:\n",
    "                print('Ep: ', i_episode,\n",
    "                      '| Ep_r: ', round(ep_r, 2))\n",
    "\n",
    "        if done:\n",
    "            break\n",
    "        s = s_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
